library(here)
library(tidyverse)
library(tidylog)
library(irr)
library(tinytable)

# =============================================================================
# Functions for LLM Annotation Quality Assessment
# =============================================================================

#' Calculate comprehensive inter-coder reliability metrics
#'
#' Compares binary LLM codes with human codes (human treated as ground
#' truth) and returns raw agreement, chance-corrected agreement,
#' classification metrics, and prevalence/bias summaries.
#'
#' @param llm_annotations Vector of LLM annotations (0/1)
#' @param human_annotations Vector of human annotations (0/1, treated as ground truth)
#' @param na.rm Logical, whether to drop pairs in which either annotation is NA
#'   (default TRUE). With `na.rm = FALSE` and NAs present, the count-based
#'   metrics come back NA instead of raising an error.
#' @return A list containing all reliability metrics: `n_total`,
#'   `confusion_matrix` (rows = LLM, columns = Human), the counts `tp`, `tn`,
#'   `fp`, `fn`, and `pct_agreement`, `cohens_kappa`, `kappa_pvalue`,
#'   `precision`, `recall`, `f1_score`, `mcc`, `krippendorff_alpha`,
#'   `prevalence_human`, `prevalence_llm`, `bias` (LLM minus human
#'   prevalence). Ratios with a zero denominator are NA; kappa and alpha are
#'   NA when the corresponding irr call fails.
calc_annotation_quality <- function(llm_annotations, human_annotations, na.rm = TRUE) {
  # Convert to numeric
  llm <- as.numeric(llm_annotations)
  human <- as.numeric(human_annotations)

  # Handle NAs: keep only pairs where both coders gave a value
  if (na.rm) {
    complete <- complete.cases(llm, human)
    llm <- llm[complete]
    human <- human[complete]
  }

  n_total <- length(llm)

  # --- Confusion Matrix Components ---
  # Treating human as ground truth
  tp <- sum(llm == 1 & human == 1)  # True positive
  tn <- sum(llm == 0 & human == 0)  # True negative
  fp <- sum(llm == 1 & human == 0)  # False positive
  fn <- sum(llm == 0 & human == 1)  # False negative

  # matrix() fills column-major: Human "0" column = (tn, fp),
  # Human "1" column = (fn, tp).
  cont_table <- matrix(c(tn, fp, fn, tp), nrow = 2,
                       dimnames = list(LLM = c("0", "1"), Human = c("0", "1")))

  # --- Percentage Agreement ---
  pct_agreement <- (tp + tn) / n_total

  # --- Cohen's Kappa ---
  # Accounts for chance agreement; any failure in irr degrades to NA.
  kappa_result <- tryCatch({
    irr::kappa2(cbind(llm, human))
  }, error = function(e) list(value = NA, p.value = NA))

  # --- Precision, Recall, F1 (treating human as ground truth) ---
  # isTRUE() keeps these NA (rather than erroring) when counts are NA,
  # which happens with na.rm = FALSE and missing annotations.
  precision <- if (isTRUE((tp + fp) > 0)) tp / (tp + fp) else NA_real_
  recall <- if (isTRUE((tp + fn) > 0)) tp / (tp + fn) else NA_real_
  f1_score <- if (!is.na(precision) && !is.na(recall) && (precision + recall) > 0) {
    2 * precision * recall / (precision + recall)
  } else NA_real_

  # --- Matthews Correlation Coefficient (MCC) ---
  # Good for imbalanced data. The counts are integers (sum over logicals), and
  # the product of the four margins overflows R's 32-bit integer range (giving
  # NA) once each margin exceeds ~215, so do the arithmetic in doubles.
  tp_d <- as.numeric(tp)
  tn_d <- as.numeric(tn)
  fp_d <- as.numeric(fp)
  fn_d <- as.numeric(fn)
  mcc_denom <- sqrt((tp_d + fp_d) * (tp_d + fn_d) * (tn_d + fp_d) * (tn_d + fn_d))
  mcc <- if (isTRUE(mcc_denom > 0)) {
    (tp_d * tn_d - fp_d * fn_d) / mcc_denom
  } else NA_real_

  # --- Krippendorff's Alpha ---
  # irr::kripp.alpha expects coders in rows and units in columns.
  kripp_result <- tryCatch({
    irr::kripp.alpha(rbind(llm, human), method = "nominal")
  }, error = function(e) list(value = NA))

  # --- Prevalence and Bias ---
  prevalence_human <- mean(human)
  prevalence_llm <- mean(llm)
  bias <- prevalence_llm - prevalence_human

  # Return results
  list(
    n_total = n_total,
    confusion_matrix = cont_table,
    tp = tp, tn = tn, fp = fp, fn = fn,
    pct_agreement = pct_agreement,
    cohens_kappa = kappa_result$value,
    kappa_pvalue = kappa_result$p.value,
    precision = precision,
    recall = recall,
    f1_score = f1_score,
    mcc = mcc,
    krippendorff_alpha = kripp_result$value,
    prevalence_human = prevalence_human,
    prevalence_llm = prevalence_llm,
    bias = bias
  )
}

#' Print annotation quality metrics in a formatted way
#'
#' Writes a plain-text report of one `calc_annotation_quality()` result to the
#' message stream. The report has an optional variable banner, the 2x2
#' confusion matrix, the reliability and classification metrics, and the
#' prevalence/bias figures. When Cohen's kappa is not NA, it also prints a
#' verbal kappa label.
#'
#' @param results Output from calc_annotation_quality function
#' @param var_name Optional name of the variable being assessed
#' @return `NULL`, invisibly; called for its messages.
print_annotation_quality <- function(results, var_name = NULL) {
  # Format with sprintf() and emit the resulting line via message().
  say <- function(fmt, ...) message(sprintf(fmt, ...))

  if (!is.null(var_name)) {
    banner <- strrep("=", 60)
    message("\n", banner)
    message("Variable: ", var_name)
    message(banner)
  }

  say("\nSample Size: %d", results$n_total)

  # Rows are LLM codes, columns are human codes.
  message("\n--- Confusion Matrix (LLM vs Human) ---")
  message("              Human")
  message("         |  0   |  1   |")
  say("LLM   0  | %4d | %4d |", results$tn, results$fn)
  say("      1  | %4d | %4d |", results$fp, results$tp)

  message("\n--- Inter-Coder Reliability ---")
  say("Percent Agreement:     %.1f%%", 100 * results$pct_agreement)
  say("Cohen's Kappa:         %.3f (p = %.4f)",
      results$cohens_kappa, results$kappa_pvalue)
  say("Krippendorff's Alpha:  %.3f", results$krippendorff_alpha)

  message("\n--- Classification Metrics (Human as Ground Truth) ---")
  say("Precision:             %.3f", results$precision)
  say("Recall:                %.3f", results$recall)
  say("F1 Score:              %.3f", results$f1_score)
  say("Matthews Corr Coef:    %.3f", results$mcc)

  message("\n--- Prevalence & Bias ---")
  say("Human prevalence (1):  %.1f%%", 100 * results$prevalence_human)
  say("LLM prevalence (1):    %.1f%%", 100 * results$prevalence_llm)
  say("Bias (LLM - Human):    %.1f%%", 100 * results$bias)

  message("\n--- Interpretation ---")
  kappa <- results$cohens_kappa
  if (is.na(kappa)) {
    return(invisible(NULL))
  }

  # Landis & Koch bands, each closed on the left: [<0), [0, .2), [.2, .4),
  # [.4, .6), [.6, .8), [.8, ...).
  band_labels <- c("Poor (less than chance)", "Slight", "Fair",
                   "Moderate", "Substantial", "Almost Perfect")
  band <- findInterval(kappa, c(0, 0.20, 0.40, 0.60, 0.80)) + 1L
  say("Kappa interpretation:  %s", band_labels[band])
}

#' Run annotation quality assessment for multiple variables
#'
#' For each position i, compares `data[[llm_vars[i]]]` against
#' `data[[human_vars[i]]]` with `calc_annotation_quality()`. It prints the
#' report with `print_annotation_quality()` and collects one summary row.
#'
#' @param data Data frame containing both LLM and human annotations
#' @param llm_vars Character vector of LLM variable names
#' @param human_vars Character vector of corresponding human variable names
#'   (same length as `llm_vars`, matched by position)
#' @return Data frame with quality metrics for each variable, one row per
#'   (LLM, human) pair in input order
run_quality_assessment <- function(data, llm_vars, human_vars) {
  # Validate up front. Otherwise mismatched lengths would silently ignore
  # pairs, and a missing column would fail deep inside
  # calc_annotation_quality() with an opaque error.
  if (length(llm_vars) != length(human_vars)) {
    stop("`llm_vars` and `human_vars` must have the same length.",
         call. = FALSE)
  }
  missing_vars <- setdiff(c(llm_vars, human_vars), names(data))
  if (length(missing_vars) > 0) {
    stop("Columns not found in `data`: ",
         paste(missing_vars, collapse = ", "), call. = FALSE)
  }

  # Store results by position rather than by name, so a repeated LLM
  # variable paired with different human variables does not overwrite an
  # earlier row.
  results_list <- vector("list", length(llm_vars))

  for (i in seq_along(llm_vars)) {
    llm_var <- llm_vars[i]
    human_var <- human_vars[i]

    result <- calc_annotation_quality(
      data[[llm_var]],
      data[[human_var]]
    )

    results_list[[i]] <- tibble(
      variable = llm_var,
      n = result$n_total,
      pct_agreement = result$pct_agreement,
      cohens_kappa = result$cohens_kappa,
      kappa_pvalue = result$kappa_pvalue,
      precision = result$precision,
      recall = result$recall,
      f1_score = result$f1_score,
      mcc = result$mcc,
      krippendorff_alpha = result$krippendorff_alpha,
      prevalence_human = result$prevalence_human,
      prevalence_llm = result$prevalence_llm,
      bias = result$bias
    )

    print_annotation_quality(result, llm_var)
  }

  bind_rows(results_list)
}

# =============================================================================
# Load and Prepare Data
# =============================================================================

# Load LLM-annotated statements
g <- read_csv(here("data", "output", "statements_analysis.csv"), show_col_types = FALSE)

# Load human annotations. Keep the join keys plus the three human codes,
# renamed with an `_h` suffix so they sit alongside the LLM columns after the
# join. NOTE(review): `gives_credit...14` looks like readr's name repair of a
# duplicated column header; confirm it is the human-coded column.
h <- read_csv(here("data", "input", "credit", "annotation_qc", "aditya_qc.csv"),
              show_col_types = FALSE)
qc <- h %>%
  select(
    id, statement,
    gives_credit_h = gives_credit...14,
    credit_biden_h = credit_biden,
    credit_governor_h = credit_governor
  )

# Join statements and annotations (statements without a human code get NA)
s <- left_join(g, qc, by = c("id", "statement"))

# =============================================================================
# Run Quality Assessment
# =============================================================================

# Variable pairs matched by position: each LLM column and its human
# counterpart, which carries the same name plus an `_h` suffix.
llm_vars <- c("gives_credit", "credit_biden", "credit_governor")
human_vars <- paste0(llm_vars, "_h")

# Restrict to the human-validated statements
s <- s %>% filter(!is.na(gives_credit_h))
dim(s)

# Per-variable reports are printed as the assessment runs
quality_results <- run_quality_assessment(s, llm_vars, human_vars)

# Print summary table
message("\n", strrep("=", 80))
message("SUMMARY TABLE")
message(strrep("=", 80))
print(quality_results)

# Persist the per-variable metrics
write_csv(quality_results, here("data", "output", "annotation_quality_results.csv"))
message("\nResults saved to: data/output/annotation_quality_results.csv")

# =============================================================================
# Examine Disagreements
# =============================================================================

message("\n", strrep("=", 80))
message("DISAGREEMENT ANALYSIS")
message(strrep("=", 80))

# Keep statements where the LLM and the human coder differ on at least one of
# the three codes. filter() keeps a row only when the combined condition is
# TRUE. A row with no TRUE comparison but at least one NA comparison is
# therefore dropped.
disagreements <- s %>%
  filter(
    (gives_credit != gives_credit_h) |
      (credit_biden != credit_biden_h) |
      (credit_governor != credit_governor_h)
  ) %>%
  select(
    id, statement,
    gives_credit, gives_credit_h,
    credit_biden, credit_biden_h,
    credit_governor, credit_governor_h
  )

message(sprintf(
  "\nTotal disagreements: %d out of %d cases (%.1f%%)",
  nrow(disagreements), nrow(s), 100 * nrow(disagreements) / nrow(s)
))

# =============================================================================
# Measurement Error Bounds for Governor Credit
# =============================================================================

message("\n", paste(rep("=", 80), collapse = ""))
message("MEASUREMENT ERROR BOUNDS: Governor Credit")
message(paste(rep("=", 80), collapse = ""))

# Confusion matrix of LLM vs human Governor-credit codes in the validation
# sample `s` (human codes treated as ground truth).
gov_result <- calc_annotation_quality(s$credit_governor, s$credit_governor_h)
fp <- gov_result$fp
tn <- gov_result$tn
fn <- gov_result$fn
tp <- gov_result$tp

# Precision and recall come from the same validation sample rather than being
# hard-coded, so they cannot drift out of sync with the counts above.
precision_gov <- gov_result$precision
recall_gov <- gov_result$recall

# 1 - precision is the false DISCOVERY rate, P(True=0 | LLM=1). The false
# positive rate P(LLM=1 | True=0) is `fpr_full`, computed below.
fpr_given_positive <- 1 - precision_gov
fnr <- 1 - recall_gov  # P(LLM=0 | True=1)

message("\nMisclassification rates:")
message(sprintf("  False discovery rate (share of LLM positives that are wrong): %.1f%%",
                fpr_given_positive * 100))
message(sprintf("  False negative rate (among true positives): %.1f%%", fnr * 100))

# Observed prevalence over the full LLM-annotated corpus. `g` already holds
# statements_analysis.csv unmodified, so reuse it instead of re-reading.
g_full <- g
observed_prev <- mean(g_full$credit_governor, na.rm = TRUE)

message(sprintf("\nObserved (LLM) prevalence of Governor credit: %.1f%%", observed_prev * 100))

# Rates needed for the misclassification correction
# Observed = Sensitivity * True + FPR * (1 - True)
# Full false positive rate: P(LLM=1 | True=0)
fpr_full <- fp / (fp + tn)
# Sensitivity (recall): P(LLM=1 | True=1)
sensitivity <- tp / (tp + fn)

message(sprintf("\nFrom confusion matrix (N=%d):", gov_result$n_total))
message(sprintf("  TP=%d, FP=%d, TN=%d, FN=%d", tp, fp, tn, fn))
message(sprintf("  Sensitivity (recall): %.2f", sensitivity))
message(sprintf("  False positive rate: %.2f", fpr_full))

# =============================================================================
# Statistical Tests for Directional Bias in Governor Credit
# =============================================================================

message("\n", paste(rep("=", 80), collapse = ""))
message("STATISTICAL TESTS FOR LLM BIAS (Governor Credit)")
message(paste(rep("=", 80), collapse = ""))

# --- 1. McNemar Test for Asymmetric Disagreement ---
# Tests whether FP ≠ FN (i.e., are disagreements systematically in one direction?)
message("\n--- McNemar Test (Asymmetric Disagreement) ---")
message("H0: False positives = False negatives (symmetric errors)")
message("H1: Errors are asymmetric (LLM systematically over- or under-attributes)")

# Same 2x2 layout as calc_annotation_quality(): rows = LLM, columns = Human.
mcnemar_result <- mcnemar.test(matrix(c(tn, fp, fn, tp), nrow = 2))
message(sprintf("\nMcNemar chi-squared: %.2f", mcnemar_result$statistic))
message(sprintf("p-value: %.4f", mcnemar_result$p.value))
message(sprintf("FP = %d, FN = %d", fp, fn))

# With no discordant pairs (fp + fn == 0) the statistic is 0/0 and the p-value
# is NaN. isTRUE() reports "no asymmetry" in that case instead of erroring.
if (isTRUE(mcnemar_result$p.value < 0.05)) {
  if (fp > fn) {
    message("Result: LLM SIGNIFICANTLY over-attributes Governor credit (FP > FN)")
  } else {
    message("Result: LLM SIGNIFICANTLY under-attributes Governor credit (FN > FP)")
  }
} else {
  message("Result: No significant asymmetry in disagreements")
}

# --- 2. Bias Estimate with Confidence Interval ---
message("\n--- Bias Estimate with 95% CI ---")

# Validation-sample point estimate: LLM minus human prevalence of Governor
# credit (positive = the LLM codes credit more often than the human coder).
llm_prev_val <- mean(s$credit_governor, na.rm = TRUE)
human_prev_val <- mean(s$credit_governor_h, na.rm = TRUE)
bias_point <- llm_prev_val - human_prev_val

message(sprintf("In validation sample (N=%d):", nrow(s)))
message(sprintf("  LLM prevalence:   %.1f%%", 100 * llm_prev_val))
message(sprintf("  Human prevalence: %.1f%%", 100 * human_prev_val))
message(sprintf("  Bias (LLM - Human): %.1f percentage points", 100 * bias_point))

# Nonparametric bootstrap: resample statements (rows of `s`) with replacement
# and recompute the prevalence difference. The fixed seed, plus one sample()
# call per replicate in order, makes the draws reproducible.
set.seed(42)
n_boot <- 1000
boot_bias <- vapply(
  seq_len(n_boot),
  function(draw) {
    rows <- sample(seq_len(nrow(s)), nrow(s), replace = TRUE)
    mean(s$credit_governor[rows], na.rm = TRUE) -
      mean(s$credit_governor_h[rows], na.rm = TRUE)
  },
  numeric(1)
)

# Percentile interval and bootstrap standard error
bias_ci <- quantile(boot_bias, c(0.025, 0.975))
bias_se <- sd(boot_bias)

message(sprintf("\nBootstrap results (B=%d):", n_boot))
message(sprintf("  Bias estimate: %.1f pp (SE = %.1f pp)", 100 * bias_point, 100 * bias_se))
message(sprintf("  95%% CI: [%.1f, %.1f] percentage points", 100 * bias_ci[1], 100 * bias_ci[2]))

# The bias is "significant" when the percentile CI excludes zero
message(
  if (bias_ci[1] > 0) {
    "  Interpretation: LLM significantly OVER-attributes Governor credit"
  } else if (bias_ci[2] < 0) {
    "  Interpretation: LLM significantly UNDER-attributes Governor credit"
  } else {
    "  Interpretation: Bias not significantly different from zero"
  }
)

# --- 3. Summary for Paper ---
message("\n--- Summary Statement for Paper ---")
# Word the direction from the sign of the point estimate and report its
# magnitude. The CI stays signed (LLM - Human) so readers can see whether
# it excludes zero.
bias_direction <- if (isTRUE(bias_point < 0)) "under-attributes" else "over-attributes"
message(sprintf(
  "The LLM %s Governor credit by %.1f percentage points (95%% CI: [%.1f, %.1f]; McNemar p = %.3f).",
  bias_direction, abs(bias_point) * 100, bias_ci[1] * 100, bias_ci[2] * 100,
  mcnemar_result$p.value
))

# Correct for misclassification to estimate true prevalence
# Observed = Sensitivity * True + FPR * (1 - True)
# Solving: True = (Observed - FPR) / (Sensitivity - FPR)
# isTRUE() skips the correction, instead of erroring, when either rate is
# NA/NaN (e.g. the validation sample has no human positives or negatives).
if (isTRUE(sensitivity != fpr_full)) {
  true_prev_estimate <- (observed_prev - fpr_full) / (sensitivity - fpr_full)
  true_prev_estimate <- max(0, min(1, true_prev_estimate))  # Bound to [0,1]

  message(sprintf("\nEstimated TRUE prevalence of Governor credit: %.1f%%", true_prev_estimate * 100))
  message(sprintf("Observed prevalence is inflated by: %.1f percentage points",
                  (observed_prev - true_prev_estimate) * 100))
  # The ratio is undefined when the corrected prevalence is clamped to 0
  if (isTRUE(true_prev_estimate > 0)) {
    message(sprintf("Inflation ratio: %.2fx", observed_prev / true_prev_estimate))
  } else {
    message("Inflation ratio: undefined (estimated true prevalence is 0)")
  }
}

# Bounds on regression coefficients (attenuation)
# The misclassification model above implies
#   E[observed | Z] = FPR + (Sensitivity - FPR) * E[true | Z],
# so a linear-probability coefficient with this variable as the OUTCOME is
# scaled by (Sensitivity - FPR) = Sensitivity + Specificity - 1 (Youden's J),
# assuming non-differential misclassification.
# NOTE(review): if the variable is instead a binary REGRESSOR, the attenuation
# factor is PPV + NPV - 1, not Youden's J.
specificity <- tn / (fp + tn)
attenuation_factor <- sensitivity + specificity - 1

message("\n--- Regression Coefficient Bounds ---")
message(sprintf("Specificity: %.2f", specificity))
message(sprintf("Attenuation factor (Youden's J): %.2f", attenuation_factor))
if (isTRUE(attenuation_factor > 0)) {
  message("\nInterpretation:")
  message(sprintf("  If true coefficient is β, observed coefficient ≈ %.2f * β", attenuation_factor))
  message(sprintf("  To recover true effect: multiply observed coefficient by %.2f", 1 / attenuation_factor))
  # Illustrative +/-10% band around the rescaled example coefficient
  message(sprintf("\nExample: If observed β = 0.10, true β is approximately %.2f to %.2f",
                  0.10 / attenuation_factor * 0.9, 0.10 / attenuation_factor * 1.1))
} else {
  message("\nAttenuation factor is not positive: the LLM codes carry no usable signal, so coefficients cannot be rescaled.")
}

# =============================================================================
# Governor Credit Bounds by Actor Type
# =============================================================================

message("\n", strrep("=", 80))
message("GOVERNOR CREDIT BOUNDS BY ACTOR")
message(strrep("=", 80))

# Correct observed prevalence for non-differential misclassification.
#
# Inverts Observed = Sensitivity * True + FPR * (1 - True), i.e.
# True = (Observed - FPR) / (Sensitivity - FPR), and clamps the result to [0, 1].
#
# observed:    numeric vector of observed (LLM) prevalences.
# sensitivity: scalar P(LLM=1 | True=1) from the validation sample.
# fpr:         scalar P(LLM=1 | True=0) from the validation sample.
# Returns a numeric vector the same length as `observed`. It is all NA when
# the correction is undefined (sensitivity equals fpr, or either is missing).
correct_prevalence <- function(observed, sensitivity, fpr) {
  if (is.na(sensitivity) || is.na(fpr) || sensitivity == fpr) {
    return(rep(NA_real_, length(observed)))
  }
  corrected <- (observed - fpr) / (sensitivity - fpr)
  # pmin/pmax clamp element-wise, so vector input keeps its length
  pmax(0, pmin(1, corrected))
}

# Standardize actor names (unlisted actor codes are kept as-is)
g_full <- g_full %>%
  mutate(
    actor_clean = case_when(
      actor %in% c("senator_1_text", "senator_2_text") ~ "Senator",
      actor == "governor_text" ~ "Governor",
      actor == "company_text" ~ "Company",
      actor == "representative_text" ~ "U.S. Rep",
      actor == "biden_text" ~ "Biden",
      TRUE ~ actor
    )
  )

# Calculate observed and corrected prevalence by actor
# CONDITIONAL on gives_credit == 1 (among statements that give credit)
actor_bounds <- g_full %>%
  filter(gives_credit == 1) %>%  # Only statements that give credit
  group_by(actor_clean) %>%
  summarize(
    n = n(),
    n_credit_gov = sum(credit_governor, na.rm = TRUE),
    observed_prev = mean(credit_governor, na.rm = TRUE),
    .groups = "drop"
  ) %>%
  mutate(
    # vapply() rather than sapply() keeps the column double, even when every
    # correction is NA or there are no groups.
    corrected_prev = vapply(observed_prev, correct_prevalence, numeric(1),
                            sensitivity = sensitivity, fpr = fpr_full),
    bias_pp = (observed_prev - corrected_prev) * 100,
    # Fixed +/-5pp band around the corrected estimate (not a statistical CI)
    lower_bound = pmax(0, corrected_prev - 0.05),
    upper_bound = pmin(1, corrected_prev + 0.05)
  ) %>%
  arrange(desc(observed_prev))

message("\nGovernor credit prevalence by actor, AMONG STATEMENTS THAT GIVE CREDIT:\n")
message("(i.e., P(credit_governor | gives_credit = 1, actor))\n")
message(sprintf("%-12s %6s %10s %10s %10s %15s",
                "Actor", "N", "Observed", "Corrected", "Bias (pp)", "Bounds"))
message(paste(rep("-", 70), collapse = ""))

# seq_len() makes this a no-op when no actor group exists (1:0 would not)
for (i in seq_len(nrow(actor_bounds))) {
  row <- actor_bounds[i, ]
  message(sprintf("%-12s %6d %9.1f%% %9.1f%% %+9.1f %6.1f%% - %.1f%%",
                  row$actor_clean,
                  row$n,
                  row$observed_prev * 100,
                  row$corrected_prev * 100,
                  row$bias_pp,
                  row$lower_bound * 100,
                  row$upper_bound * 100))
}

message("\n")
message("Notes:")
message(sprintf("  - Correction uses validation sample sensitivity=%.2f, FPR=%.2f", sensitivity, fpr_full))
message("  - Bounds add ±5pp uncertainty around corrected estimate")
message("  - Positive bias means LLM over-attributes Governor credit")

# Save bounds table
write_csv(actor_bounds, here("data", "output", "governor_credit_bounds_by_actor.csv"))
message("\nBounds saved to: data/output/governor_credit_bounds_by_actor.csv")

# =============================================================================
# Generate LaTeX Table for Supplemental Appendix
# =============================================================================

message("\n", strrep("=", 80))
message("GENERATING LATEX TABLE FOR ANNOTATION QUALITY")
message(strrep("=", 80))

# Build the display table. Variables get readable labels (unknown variables
# keep their raw name), and each metric becomes a LaTeX-ready string.
# `\\%` escapes the percent sign for LaTeX.
quality_table <- quality_results %>%
  mutate(
    Variable = coalesce(
      unname(c(
        gives_credit = "Gives Credit",
        credit_biden = "Biden Credit",
        credit_governor = "Governor Credit"
      )[variable]),
      variable
    ),
    Agreement = sprintf("%.1f\\%%", 100 * pct_agreement),
    `Cohen's $\\kappa$` = sprintf("%.3f", cohens_kappa),
    Precision = sprintf("%.3f", precision),
    Recall = sprintf("%.3f", recall),
    `F1 Score` = sprintf("%.3f", f1_score)
  ) %>%
  select(
    Variable, N = n, Agreement, `Cohen's $\\kappa$`,
    Precision, Recall, `F1 Score`
  )

# Define table notes with interpretation guidance and bounding test results.
# The sample size and the direction of the Governor-credit bias come from the
# data rather than being hard-coded, so the note stays accurate if the
# validation sample or the estimate changes.
# NOTE(review): the model name "(GPT-4)" is still hard-coded; confirm it.
bias_direction_note <- if (isTRUE(bias_point < 0)) "under-attributing" else "over-attributing"
table_notes <- paste0(
  "\\textit{Notes}: Quality metrics comparing LLM annotations (GPT-4) to human ground truth ",
  sprintf("annotations for $N=%d$ randomly sampled statements. ", nrow(s)),
  "Agreement is raw percent agreement. ",
  "Cohen's $\\kappa$ accounts for chance agreement; values $>0.6$ indicate substantial agreement. ",
  "Precision is P(Human=1 $|$ LLM=1); Recall is P(LLM=1 $|$ Human=1). ",
  "F1 is the harmonic mean of Precision and Recall. ",
  sprintf(
    "For Governor Credit, McNemar's test for asymmetric errors yields $p = %.3f$, ",
    mcnemar_result$p.value
  ),
  sprintf(
    "with the LLM %s credit by %.1f percentage points (95\\%% bootstrap CI: [%.1f, %.1f]).",
    bias_direction_note, abs(bias_point) * 100, bias_ci[1] * 100, bias_ci[2] * 100
  )
)

# Generate tinytable: build the table from `quality_table`. The caption
# carries the \label used for cross-referencing, and `table_notes` supplies
# the footnote text.
tt <- tinytable::tt(
  quality_table,
  caption = "LLM Annotation Quality: Comparison to Human Ground Truth \\label{tab:annotation_quality}",
  notes = table_notes
) |>
  # Center the six metric columns (N through F1 Score); left-align labels.
  tinytable::style_tt(j = 2:7, align = "c") |>
  tinytable::style_tt(j = 1, align = "l") |>
  # NOTE(review): the "tabular" theme may drop the float environment. If so,
  # the caption/notes and the "H" placement set below may not reach the .tex
  # output. Confirm against the installed tinytable version.
  tinytable::theme_tt("tabular", style = "tabular") |>
  # escape = FALSE so LaTeX markup built upstream (e.g. `\%` in Agreement,
  # `$\kappa$` in a header) is passed through rather than escaped.
  tinytable::format_tt(escape = FALSE) |>
  tinytable::theme_tt("placement", latex_float = "H")

# Save to output directory (overwrites any previous version)
output_file <- here("output", "pnas", "tables", "tab_S39_annotation_quality.tex")
tinytable::save_tt(tt, output = output_file, overwrite = TRUE)

message(sprintf("\nLaTeX table saved to: %s", output_file))